Fix failing default transform for LKJCorr#7065
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #7065 +/- ##
=======================================
Coverage 92.19% 92.20%
=======================================
Files 101 101
Lines 16893 16901 +8
=======================================
+ Hits 15575 15584 +9
+ Misses 1318 1317 -1
|
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
LKJCorr
|
Is it good on your end? Asking because it's marked as a draft |
|
@ricardoV94 With the new suggestion the test fails locally with the NotImplementedError: Univariate transform MultivariateIntervalTransform cannot be applied to multivariate lkjcorr_rv{1, (0, 0), floatX, False}🤔 |
Where is that check coming from? We might need to add some meta-info to the Transform |
|
The problem is the logp of the distribution is incorrectly implemented. It's returning a scalar instead of a vector of |
|
Yeah! It is failing locally. It's good that you caught up on this with the test! Do you think there is an "easy" fix? |
|
We should add a You can parametrize the test to have two cases pymc/tests/distributions/test_multivariate.py Lines 1048 to 1055 in 9b4bf2a Also this test shouldn't be in the |
It's not trivial, it requires thinking carefully about batch dimensions, like we did in this PR: #6897 |
|
Also could you reintroduce the change from the other PR where we always run this check instead of being in the else branch? pymc/pymc/logprob/transform_value.py Lines 128 to 133 in 04a03b5 |
|
Added the suggested changes :) |
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
|
We should add a check in the logp similar to this: pymc/pymc/distributions/multivariate.py Lines 1249 to 1252 in e67a317 Should be a NotImplementedError though |
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
|
We now have two tests failing because of the new tests/distributions/test_multivariate.py::TestMoments::test_lkjcorr_moment[3-1-1-expected2] FAILED [ 87%]
tests/distributions/test_multivariate.py::TestMoments::test_lkjcorr_moment[5-1-size3-expected3] FAILED [ 87%] |
Yup |
|
@ricardoV94 we are back to 🟢 :) |
|
Thank you for all your help @ricardoV94 ❤️ |
|
Thanks @juanitorduz |
Closes #7002
Wt take a different direction from #7023
📚 Documentation preview 📚: https://pymc--7065.org.readthedocs.build/en/7065/